from narrativeNLP.wrappers import build_narrative_model

import pandas as pd
import pickle as pk
from tqdm import tqdm
import json
import glob

import sys
# Number of clusters passed to build_narrative_model (as n_clusters), read from
# the first command-line argument, e.g. `python <script>.py 1000`.
args = sys.argv
if len(args) < 2:
    sys.exit('Usage: {} N_CLUSTERS'.format(args[0]))
try:
    c = int(args[1])
except ValueError:
    sys.exit('N_CLUSTERS must be an integer, got {!r}'.format(args[1]))
if c < 1:
    sys.exit('N_CLUSTERS must be a positive integer, got {}'.format(c))
print('Number of clusters: {}'.format(c))

# Alternatively, hard-code the number of clusters instead of reading it from the command line:
# c = 1000

# Load the batch definitions. Each entry is indexed by batch number below and
# iterated as a collection of sentence-CSV file paths.
batches = pd.read_pickle('../data/batches.pk')
n_batches = len(batches)
print('Number of batches:', n_batches)

# Define function to read SRL json

def read_json_file(path):
    """Load and return the JSON document stored at *path*.

    Used below to read SRL annotation shards; the result is expected to be a
    list of per-sentence annotations (entries may be null/None).

    The file is decoded as UTF-8, the standard JSON encoding, rather than with
    the platform's locale-dependent default encoding.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

# Save the sentences by batch (for later reference)

# Sentence files known to be empty; pd.read_csv cannot parse them, so they are skipped.
KNOWN_EMPTY_FILES = frozenset({
    '../data/gpo_sentences/1994-02-01_528.csv',
    '../data/gpo_sentences/1994-01-31_507.csv',
    '../data/gpo_sentences/1994-02-01_549.csv',
    '../data/gpo_sentences/1994-01-31_430.csv',
})

# Number of batches exported (indices 0..28).
# NOTE(review): hard-coded in the original; confirm whether this should be len(batches).
N_EXPORTED_BATCHES = 29


def _concat_sentence_files(filenames):
    """Read each sentence CSV in *filenames* and return all rows as one DataFrame.

    Each row is tagged with its source path in a 'doc' column. Known-empty
    files, and any other file pandas reports as empty, are skipped with an
    'Empty file.' message.

    Raises:
        ValueError: if none of *filenames* could be read.
    """
    frames = []
    for filename in tqdm(filenames):
        if filename in KNOWN_EMPTY_FILES:
            print('Empty file.')
            continue
        try:
            frame = pd.read_csv(filename)
        except pd.errors.EmptyDataError:
            # Zero-byte file not in the known list; it contributes no sentences.
            print('Empty file.')
            continue
        frame['doc'] = filename
        frames.append(frame)
    if not frames:
        raise ValueError('No readable sentence files in batch.')
    return pd.concat(frames, ignore_index=True)


for batch_id in range(N_EXPORTED_BATCHES):
    split_sentences = _concat_sentence_files(batches[batch_id])
    split_sentences.to_csv('../data/gpo_split_sentences_by_batch/split_sentences_{0}.csv'.format(batch_id), index=False)

# Release the last exported batch before the training batch is loaded.
del batch_id, split_sentences

# The narrative model will be trained only on the first batch
# Get the names of the files with the speech text

filenames_sents = batches[0]

# Open all sentence files of the first batch and concatenate them into one
# DataFrame. Empty files are skipped (not passed to pd.concat), because
# pd.read_csv raises EmptyDataError on zero-byte files.

list_of_dataframes = []

for filename in tqdm(filenames_sents):
    try:
        temp = pd.read_csv(filename)
    except pd.errors.EmptyDataError:
        # An empty file has no sentences, so skipping it cannot shift alignment.
        print('Empty file.')
        continue
    temp['doc'] = filename
    list_of_dataframes.append(temp)

split_sentences = pd.concat(list_of_dataframes, ignore_index=True)
# Keep a row-aligned pair: (source file path per sentence, sentence text).
split_sentences = (list(split_sentences['doc']), list(split_sentences['sentence']))

# Load the SRL annotations for the sentences of the first batch

def _srl_shard_key(path):
    """Sort key that orders SRL shard files by their trailing numeric index.

    glob.glob returns paths in filesystem order. Plain string sorting would
    also place '..._10.json' before '..._2.json'. Paths whose suffix is not an
    integer sort after the numbered ones, alphabetically.

    NOTE(review): assumes the suffix after the last '_' is the shard's chunk
    index in sentence order; confirm against the script that wrote the shards.
    """
    suffix = path[:-len('.json')].rsplit('_', 1)[-1]
    if suffix.isdecimal():
        return (0, int(suffix), path)
    return (1, 0, path)


# The concatenated shards must line up one-to-one with split_sentences, so
# they are read in a deterministic order.
srl_files = sorted(glob.glob('../data/gpo_srl_annotations/srl_res_small_0_*.json'), key=_srl_shard_key)
srl_res = []

for f in srl_files:
    srl_res.extend(read_json_file(f))

# Missing (None) annotations become an empty parse so every entry has the same keys.
srl_res = [{'words': [], 'verbs': []} if srl is None else srl for srl in srl_res]

lengths_match = len(srl_res) == len(split_sentences[1])
print('Length check:', lengths_match)
if not lengths_match:
    # Building a model on misaligned annotations would silently pair SRL
    # parses with the wrong sentences.
    raise ValueError(
        'SRL annotations ({}) and sentences ({}) are misaligned.'.format(
            len(srl_res), len(split_sentences[1])))

# Load stopwords

def _read_word_list(path):
    """Return the lines of the text file at *path*, each stripped of surrounding whitespace."""
    with open(path, 'r') as handle:
        return [line.strip() for line in handle]


# Word lists that are combined into the stop_words argument of build_narrative_model.
congress_stopwords = _read_word_list('../data/dictionaries/congress_stopwords.txt')
common_stopwords = _read_word_list('../data/dictionaries/common_stopwords.txt')
congress_members = _read_word_list('../data/dictionaries/congress_members.txt')
us_states = _read_word_list('../data/dictionaries/us_states.txt')
numbers = _read_word_list('../data/dictionaries/numbers.txt')

stop_words = [
    *common_stopwords,
    *congress_stopwords,
    *congress_members,
    *us_states,
    *numbers,
]

# Build narrative model

print('Building narrative model...')

# The agent role in PropBank-style SRL tags is 'ARG0' (digit zero). The former
# spelling 'ARGO' (letter O) would not match those tags, so agents were ignored.
narrative_model = build_narrative_model(srl_res = srl_res,
                                        sentences = split_sentences[1], # list of sentences
                                        roles_considered = ['ARG0', 'B-V', 'B-ARGM-NEG', 'B-ARGM-MOD', 'ARG1','ARG2'],
                                        roles_with_embeddings = [['ARG0','ARG1','ARG2']],
                                        embeddings_type = 'gensim_keyed_vectors',
                                        embeddings_path = 'glove-wiki-gigaword-300',
                                        n_clusters = [[c]],
                                        verbose = 1,
                                        roles_with_entities = ['ARG0','ARG1','ARG2'],
                                        top_n_entities = 1000,
                                        dimension_reduce_verbs = True,
                                        save_to_disk = '../models/',
                                        max_length = 4,
                                        remove_punctuation = True,
                                        remove_digits = True,
                                        remove_chars = '',
                                        stop_words = stop_words,
                                        lowercase = True,
                                        strip = True,
                                        remove_whitespaces = True,
                                        lemmatize = True,
                                        stem = False,
                                        tags_to_keep = None,
                                        remove_n_letter_words = 1,
                                        progress_bar = False)


print('Saving to disk for specific cluster combination...')

# Pickle the fitted model under a filename keyed by the cluster count, so runs
# with different N_CLUSTERS do not overwrite each other.
save_to_disk = '../models/'
model_path = f'{save_to_disk}narrative_model_{c}_clusters.pk'
with open(model_path, 'wb') as out_file:
    pk.dump(narrative_model, out_file)
